{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_TuAgGNKt5HO"
   },
   "source": [
    "# Tutorial: Exporting StableHLO from JAX\n",
    "\n",
    "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][jax-tutorial-colab]\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][jax-tutorial-kaggle]\n",
    "\n",
    "JAX is a Python library for high-performance numerical computing. This tutorial shows how to export JAX and Flax (JAX-powered neural network library) models to StableHLO, and directly to TensorFlow SavedModel.\n",
    "\n",
    "## Tutorial Setup\n",
    "\n",
    "### Install required dependencies\n",
    "\n",
    "We use `jax` and `jaxlib` (JAX's support library with compiled binaries), along with `flax` and `transformers` for some models to export.\n",
    "We also need to install `tensorflow` to work with SavedModel, and recommend using `tensorflow-cpu` or `tf-nightly` for this tutorial.\n",
    "\n",
    "[jax-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb\n",
    "[jax-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "id": "ENUcO6aML-Nq",
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "!pip install -U jax jaxlib flax transformers tensorflow-cpu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "id": "HqjeC_QSugYj"
   },
   "outputs": [],
   "source": [
    "#@title Define `get_stablehlo_asm` to help with MLIR printing\n",
    "from jax._src.interpreters import mlir as jax_mlir\n",
    "from jax._src.lib.mlir import ir\n",
    "\n",
    "# Returns prettyprint of StableHLO module without large constants\n",
    "def get_stablehlo_asm(module_str):\n",
    "  with jax_mlir.make_ir_context():\n",
    "    stablehlo_module = ir.Module.parse(module_str, context=jax_mlir.make_ir_context())\n",
    "    return stablehlo_module.operation.get_asm(large_elements_limit=20)\n",
    "\n",
    "# Disable logging for better tutorial rendering\n",
    "import logging\n",
    "logging.disable(logging.WARNING)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LNlFj80UwX0D"
   },
   "source": [
    "_Note: This helper uses a JAX internal API that may break at any time, but it serves no functional purpose in the tutorial aside from readability._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TEfsW69IBp_V"
   },
   "source": [
    "## Export JAX model to StableHLO using `jax.export`\n",
    "\n",
    "In this section we'll export a basic JAX function and a Flax model to StableHLO.\n",
    "\n",
    "The preferred API for export is [`jax.export`](https://jax.readthedocs.io/en/latest/jax.export.html#module-jax.export). The function to export must be JIT transformed, specifically a result of `jax.jit`, to be exported to StableHLO."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3MzMjjf2iIk2"
   },
   "source": [
    "### Export basic JAX model to StableHLO\n",
    "\n",
    "Let's start by exporting a basic `plus` function to StableHLO, using `np.int32` argument types to trace the function.\n",
    "\n",
    "Export requires specifying shapes using `jax.ShapeDtypeStruct`, which can be constructed from NumPy values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "v-GN3vPbvFoa",
    "outputId": "b128c08f-3591-4142-fd67-561812bb3d4e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "module @jit_plus attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n",
      "  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = \"\"}) {\n",
      "    %0 = stablehlo.add %arg0, %arg1 : tensor<i32>\n",
      "    return %0 : tensor<i32>\n",
      "  }\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "from jax import export\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "# Create a JIT-transformed function\n",
    "@jax.jit\n",
    "def plus(x,y):\n",
    "  return jnp.add(x,y)\n",
    "\n",
    "# Create abstract input shapes\n",
    "inputs = (np.int32(1), np.int32(1),)\n",
    "input_shapes = [jax.ShapeDtypeStruct(input.shape, input.dtype) for input in inputs]\n",
    "\n",
    "# Export the function to StableHLO\n",
    "stablehlo_add = export.export(plus)(*input_shapes).mlir_module()\n",
    "print(get_stablehlo_asm(stablehlo_add))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xZWVAHQzBsEM"
   },
   "source": [
    "### Export Hugging Face FlaxResNet18 to StableHLO\n",
    "\n",
    "Now let's look at a simple model that appears in the wild, `resnet18`.\n",
    "\n",
    "We'll export a `flax` model from the Hugging Face `transformers` ResNet page, [FlaxResNetModel](https://huggingface.co/docs/transformers/en/model_doc/resnet#transformers.FlaxResNetModel). This steps setup was copied from the Hugging Face documentation.\n",
    "\n",
    "The documentation also states: _\"Finally, this model supports inherent JAX features such as: **Just-In-Time (JIT) compilation** ...\"_ which means it is perfect for export.\n",
    "\n",
    "Similar to our very basic example, our steps for export are:\n",
    "\n",
    "1. Instantiate a callable (model/function) \n",
    "2. JIT-transform it with `jax.jit`\n",
    "3. Specify shapes for export using `jax.ShapeDtypeStruct` on NumPy values\n",
    "4. Use the JAX `export` API to get a StableHLO module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "53T7jO-v_6PC",
    "outputId": "0536386d-09d2-49a3-b51c-951c20f6e49b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n",
      "  func.func public @main(%arg0: tensor<1x3x224x224xf32>) -> (tensor<1x512x7x7xf32> {jax.result_info = \"[0]\"}, tensor<1x512x1x1xf32> {jax.result_info = \"[1]\"}) {\n",
      "    %c = stablehlo.constant dense<49> : tensor<i32>\n",
      "    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>\n",
      "    %cst_0 = stablehlo.constant dense<0xFF800000> : tensor<f32>\n",
      "    %cst_1 = stablehlo.constant dense<9.99999974E-6> : tensor<f32>\n",
      "    %cst_2 = stablehlo.constant dense_reso \n",
      "...\n",
      " func.func private @relu_3(%arg0: tensor<1x7x7x512xf32>) -> tensor<1x7x7x512xf32> {\n",
      "    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>\n",
      "    %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<1x7x7x512xf32>\n",
      "    %1 = stablehlo.maximum %arg0, %0 : tensor<1x7x7x512xf32>\n",
      "    return %1 : tensor<1x7x7x512xf32>\n",
      "  }\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoImageProcessor, FlaxResNetModel\n",
    "import jax\n",
    "import numpy as np\n",
    "\n",
    "# Construct jit-transformed flax model with sample inputs\n",
    "resnet18 = FlaxResNetModel.from_pretrained(\"microsoft/resnet-18\", return_dict=False)\n",
    "resnet18_jit = jax.jit(resnet18)\n",
    "sample_input = np.random.randn(1, 3, 224, 224)\n",
    "input_shape = jax.ShapeDtypeStruct(sample_input.shape, sample_input.dtype)\n",
    "\n",
    "# Export to StableHLO\n",
    "stablehlo_resnet18_export = export.export(resnet18_jit)(input_shape)\n",
    "resnet18_stablehlo = get_stablehlo_asm(stablehlo_resnet18_export.mlir_module())\n",
    "print(resnet18_stablehlo[:600], \"\\n...\\n\", resnet18_stablehlo[-345:])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "X2MC9F7Zlx6E"
   },
   "source": [
    "### Export with dynamic batch size\n",
    "\n",
    "Now let's export that same model with a dynamic batch size!\n",
    "\n",
    "In the first example, we used an input shape of `tensor<1x3x224x224xf32>`, specifying strict constraints on the input shape. If we want to defer the concrete shapes used in compilation until a later point, we can specify a `symbolic_shape`. In this example, we'll export using `tensor<?x3x224x224xf32>`.\n",
    "\n",
    "Symbolic shapes are specified using `export.symbolic_shape`, with letters representing symint dimensions. For example, a valid 2-d matrix multiplication could use symbolic constraints of: `2,a * a,5` to ensure the refined program will have valid shapes. Symbolic integer names are kept track of by an `export.SymbolicScope` to avoid unintentional name clashes."
   ]
  },
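  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the matrix multiplication constraint concrete, here is a minimal sketch, not part of the original tutorial, that exports a hypothetical `matmul` helper with the shapes `2,a` and `a,5`. Both shapes are created in the same `export.SymbolicScope`, so `a` names the same dimension in each. Printing `in_avals` and `out_avals` should show that only the shared dimension stays symbolic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import export\n",
    "\n",
    "# Hypothetical helper for illustration: a plain matrix multiplication\n",
    "@jax.jit\n",
    "def matmul(x, y):\n",
    "  return jnp.matmul(x, y)\n",
    "\n",
    "# Both shapes share one scope so that `a` refers to the same dimension\n",
    "matmul_scope = export.SymbolicScope()\n",
    "lhs_shape = jax.ShapeDtypeStruct(export.symbolic_shape(\"2, a\", scope=matmul_scope), jnp.float32)\n",
    "rhs_shape = jax.ShapeDtypeStruct(export.symbolic_shape(\"a, 5\", scope=matmul_scope), jnp.float32)\n",
    "\n",
    "# Export and inspect the abstract input and output values\n",
    "matmul_export = export.export(matmul)(lhs_shape, rhs_shape)\n",
    "print(matmul_export.in_avals)\n",
    "print(matmul_export.out_avals)"
   ]
  },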
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "sIkbtViEMJ3T",
    "outputId": "165019c7-e771-43d6-c324-6bb4222798ec"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n",
      "  func.func public @main(%arg0: tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32> {jax.result_info = \"[0]\"}, tensor<?x512x1x1xf32> {jax.result_info = \"[1]\"}) {\n",
      "    %c = stablehlo.constant dense<1> : tensor<i32>\n",
      "    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x3x224x224xf32>) -> tensor<i32>\n",
      "    %1 = stablehlo.compare  GE, %0, %c,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>\n",
      "    stablehlo.custom_call @shape_assertion(%1, %0) {api_version = 2 : i32, error_message = \"Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'a'. Using the following polymorphic shapes specifications: args[0].shape = (a, 3, 224, 224). Obtained dimension variables: 'a' = {0} from specification 'a' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.\", has_side_effect = true} : (tensor<i1>, tensor<i32>) -> ()\n",
      "    %2:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<i32>, tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>)\n",
      "    return %2#0, %2#1 : tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>\n",
      "  }\n",
      "  func.func private @_wrapped_jax_export_main(%arg0: tensor<i32> {jax.global_constant = \"a\"}, %arg1: tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32> {jax.result_info = \"[0]\"}, tensor<?x512x1x1xf32> {jax.result_info = \"[1]\"}) {\n",
      "    %c = stablehlo.constant dense<1> : tensor<1xi32>\n",
      "    %c_0 = stablehlo.constant dense<49> : tensor<i32>\n",
      "    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>\n",
      "    %c_1 = stablehlo.constant dense<512> : tensor<1xi32>\n",
      "    %c_2 = stablehlo.constant dense<7> : tensor<1xi32>\n",
      "    %c_3 = stablehlo.constant dense<256> : tensor<1x \n",
      "...\n",
      " , tensor<1xi32>) -> tensor<4xi32>\n",
      "    %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x14x14x256xf32>\n",
      "    %3 = stablehlo.maximum %arg1, %2 : tensor<?x14x14x256xf32>\n",
      "    return %3 : tensor<?x14x14x256xf32>\n",
      "  }\n",
      "  func.func private @relu_3(%arg0: tensor<i32> {jax.global_constant = \"a\"}, %arg1: tensor<?x7x7x512xf32>) -> tensor<?x7x7x512xf32> {\n",
      "    %c = stablehlo.constant dense<512> : tensor<1xi32>\n",
      "    %c_0 = stablehlo.constant dense<7> : tensor<1xi32>\n",
      "    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>\n",
      "    %0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>\n",
      "    %1 = stablehlo.concatenate %0, %c_0, %c_0, %c, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>\n",
      "    %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x7x7x512xf32>\n",
      "    %3 = stablehlo.maximum %arg1, %2 : tensor<?x7x7x512xf32>\n",
      "    return %3 : tensor<?x7x7x512xf32>\n",
      "  }\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Construct dynamic sample inputs\n",
    "dyn_scope = export.SymbolicScope()\n",
    "dyn_input_shape = jax.ShapeDtypeStruct(export.symbolic_shape(\"a,3,224,224\", scope=dyn_scope), np.float32)\n",
    "\n",
    "# Export to StableHLO\n",
    "dyn_resnet18_export = export.export(resnet18_jit)(dyn_input_shape)\n",
    "dyn_resnet18_stablehlo = get_stablehlo_asm(dyn_resnet18_export.mlir_module())\n",
    "print(dyn_resnet18_stablehlo[:1900], \"\\n...\\n\", dyn_resnet18_stablehlo[-1000:])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dRIu7xlSoDUK"
   },
   "source": [
    "A few things to note in the exported StableHLO:\n",
    "\n",
    "1. The exported program now has `tensor<?x3x224x224xf32>`. These input types can be refined in many ways: StableHLO has APIs to [refine shapes](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L69-L99) and [canonicalize dynamic programs](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L18-L28) to static programs. TensorFlow SavedModel execution also takes care of refinement which we'll see in the next example.\n",
    "2. JAX will generate guards to ensure the values of `a` are valid, in this case `a > 1` is checked. These can be washed away at compile time once refined."
   ]
  },
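  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "JAX can also specialize the symbolic dimension when an exported function is called with concrete arrays. The sketch below is an illustration under assumptions: `scale_rows` is a hypothetical stand-in for a model, and it uses `Exported.call` to invoke one exported artifact with several batch sizes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import numpy as np\n",
    "from jax import export\n",
    "\n",
    "# Hypothetical stand-in for a model: doubles every element of its input\n",
    "@jax.jit\n",
    "def scale_rows(x):\n",
    "  return x * 2.0\n",
    "\n",
    "# Export once with a symbolic batch dimension `a`\n",
    "batch_poly_shape = jax.ShapeDtypeStruct(export.symbolic_shape(\"a, 4\"), np.float32)\n",
    "scale_export = export.export(scale_rows)(batch_poly_shape)\n",
    "\n",
    "# Call the same exported function with different concrete batch sizes\n",
    "for batch in (1, 3, 8):\n",
    "  result = scale_export.call(np.ones((batch, 4), dtype=np.float32))\n",
    "  print(batch, result.shape)"
   ]
  },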
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xFU5M6Xm1U8_"
   },
   "source": [
    "## Export to TensorFlow SavedModel\n",
    "\n",
    "It is common to export a StableHLO model to SavedModel for interoperability with existing compilation pipelines, existing TensorFlow tooling, or serving via [TensorFlow Serving](https://github.com/tensorflow/serving).\n",
    "\n",
    "JAX makes it easy to pack StableHLO into a SavedModel, and load that SavedModel in the future. For this section, we'll be using our dynamic model from the previous section."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lf7Fnsrop7BD"
   },
   "source": [
    "### Export to SavedModel using `jax2tf`\n",
    "\n",
    "JAX provides a simple API for exporting StableHLO into a format that can be packaged in SavedModel in `jax.experimental.jax2tf`. This uses the `export` function under the hood, so the same `jit` requirements apply.\n",
    "\n",
    "Full details on `jax2tf` can be found in the [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#jax-and-tensorflow-interoperation-jax2tfcall_tf). For this example, we'll only need to know the `polymorphic_shapes` option to specify our dynamic batch dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "GXkgtX7QEiWa",
    "outputId": "9034836e-4ba1-4210-c2c1-316777d4ad89"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[34massets\u001b[m\u001b[m         fingerprint.pb saved_model.pb \u001b[34mvariables\u001b[m\u001b[m\n"
     ]
    }
   ],
   "source": [
    "from jax.experimental import jax2tf\n",
    "import tensorflow as tf\n",
    "\n",
    "exported_f = jax2tf.convert(resnet18, polymorphic_shapes=[\"(a,3,224,224)\"])\n",
    "\n",
    "# Copied from the jax2tf README.md > Usage: saved model\n",
    "my_model = tf.Module()\n",
    "my_model.f = tf.function(exported_f, autograph=False).get_concrete_function(tf.TensorSpec([None, 3, 224, 224], tf.float32))\n",
    "tf.saved_model.save(my_model, '/tmp/resnet18_tf', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))\n",
    "\n",
    "!ls /tmp/resnet18_tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CmKABmLdrS3C"
   },
   "source": [
    "### Reload and call the SavedModel\n",
    "\n",
    "Now we can load that SavedModel and compile using our `sample_input` from a previous example.\n",
    "\n",
    "_Note: The restored model does *not* require JAX to run, just XLA._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "9Az3dXXWrVDM",
    "outputId": "cac6d3b1-126e-4e66-dc4d-f2a798ede463"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result shape: (1, 512, 7, 7)\n"
     ]
    }
   ],
   "source": [
    "restored_model = tf.saved_model.load('/tmp/resnet18_tf')\n",
    "restored_result = restored_model.f(tf.constant(sample_input, tf.float32))\n",
    "print(\"Result shape:\", restored_result[0].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "a2Dsm2oF5jn4"
   },
   "source": [
    "## Troubleshooting\n",
    "\n",
    "### `jax.jit` issues\n",
    "\n",
    "If the function can be JIT'ed, then it can be exported. Ensure `jax.jit` works first, or look in desired project for uses of JIT already (for example, [AlphaFold's `apply`](https://github.com/google-deepmind/alphafold/blob/dbe2a438ebfc6289f960292f15dbf421a05e563d/alphafold/model/model.py#L89) can be exported easily). \n",
    "\n",
    "See [JAX's JIT compilation documentation](https://jax.readthedocs.io/en/latest/jit-compilation.html) and [`jax.jit` API reference and examples](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) for troubleshooting JIT transformations. The most common issue is control flow, which can often be resolved with `static_argnums` / `static_argnames` as in the linked example.\n",
    "\n",
    "### Support tickets\n",
    "\n",
    "You can open an issue on GitHub for further help. Include a reproducible example using one of the above APIs in your issue report, this will help get the issue resolved much quicker!"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
